{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forking\n",
    "ChromaDB now supports forking. Below is an example that uses forking to chunk and embed a GitHub repo, fork the collection for a new GitHub branch, and apply diffs to the new branch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use %pip (not !pip) so packages install into the active kernel's environment\n",
    "%pip install chromadb --quiet\n",
    "%pip install tree-sitter --quiet\n",
    "%pip install numpy --quiet\n",
    "%pip install tree-sitter-language-pack --quiet\n",
    "from tree_sitter import Language, Parser, Tree\n",
    "from tree_sitter_language_pack import get_language, get_parser\n",
    "import requests\n",
    "import base64\n",
    "import os\n",
    "import getpass\n",
    "from tqdm import tqdm\n",
    "import chromadb\n",
    "from chromadb.utils.embedding_functions import JinaEmbeddingFunction\n",
    "from chromadb.utils.results import query_result_to_dfs\n",
    "from chromadb.api.models.Collection import Collection\n",
    "from chromadb.api import ClientAPI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Target repository and branch configuration for the demo\n",
    "PY_LANGUAGE = get_language(\"python\")  # NOTE(review): appears unused below; the chunker calls get_parser itself — confirm\n",
    "REPO_OWNER = \"jairad26\"\n",
    "REPO_NAME = \"Django-WebApp\"\n",
    "EXISTING_BRANCH = \"main\"\n",
    "NEW_BRANCH = \"test1\"\n",
    "# Prompt for secrets at runtime instead of hardcoding them in the notebook\n",
    "os.environ[\"GITHUB_API_KEY\"] = getpass.getpass(\"Github API Key:\")\n",
    "os.environ[\"CHROMA_JINA_API_KEY\"] = getpass.getpass(\"Jina API Key:\")\n",
    "os.environ[\"CHROMA_API_KEY\"] = getpass.getpass(\"Chroma API Key:\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Chunker and Github Helpers\n",
    "Below are two helper classes: `CodeChunker` and `GitHubRepoProcessor`.\n",
    "\n",
    "`CodeChunker` is a custom tree-sitter implementation that chunks files into AST nodes and converts them into embeddable chunks.\n",
    "\n",
    "`GitHubRepoProcessor` is a wrapper around the GitHub API that makes it easier to pull file contents and calculate diffs between branches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class CodeChunker:\n",
    "    \"\"\"Chunk source code into embedding-sized pieces using a tree-sitter AST.\"\"\"\n",
    "\n",
    "    def __init__(self, language: str = \"python\", max_chunk_size=500):\n",
    "        \"\"\"Initialize the Python code chunker.\n",
    "        \n",
    "        Args:\n",
    "            language: The programming language of the code\n",
    "            max_chunk_size: Maximum chunk size in bytes\n",
    "        \"\"\"\n",
    "        # Create a parser\n",
    "        self.parser = get_parser(language)\n",
    "        \n",
    "        # Set maximum chunk size (in bytes)\n",
    "        self.max_chunk_size = max_chunk_size\n",
    "        \n",
    "        # Define what node types we consider as \"chunkable\"\n",
    "        # NOTE(review): these are Python grammar node names; other languages would need their own set\n",
    "        self.chunkable_types = {\n",
    "            \"function_definition\", \n",
    "            \"class_definition\",\n",
    "            \"if_statement\",\n",
    "            \"for_statement\",\n",
    "            \"while_statement\",\n",
    "            \"try_statement\",\n",
    "            \"with_statement\",\n",
    "            \"match_statement\"\n",
    "        }\n",
    "    \n",
    "    def parse_code(self, code) -> Tree:\n",
    "        \"\"\"Parse source text into a tree-sitter AST.\"\"\"\n",
    "        return self.parser.parse(bytes(code, \"utf8\"))\n",
    "    \n",
    "    def get_node_code(self, node, code):\n",
    "        \"\"\"Return the slice of `code` covered by `node`'s byte span.\"\"\"\n",
    "        return code[node.start_byte:node.end_byte]\n",
    "    \n",
    "    def get_node_name(self, node, code):\n",
    "        \"\"\"Return the identifier of a function/class definition node, or None.\"\"\"\n",
    "        if node.type not in (\"function_definition\", \"class_definition\"):\n",
    "            return None\n",
    "        # First child of type 'identifier' is the definition's name\n",
    "        return next(\n",
    "            (self.get_node_code(child, code) for child in node.children if child.type == \"identifier\"),\n",
    "            None,\n",
    "        )\n",
    "    \n",
    "    def get_node_info(self, node, code, parent_type=None):\n",
    "        \"\"\"Build a metadata dict (type, name, code text, line/byte spans, size) for a node.\"\"\"\n",
    "        return {\n",
    "            \"type\": node.type,\n",
    "            \"name\": self.get_node_name(node, code),\n",
    "            \"parent_type\": parent_type,\n",
    "            \"code\": self.get_node_code(node, code),\n",
    "            # start_point/end_point are (row, column); keep only the row\n",
    "            \"start_line\": node.start_point[0],\n",
    "            \"end_line\": node.end_point[0],\n",
    "            \"start_byte\": node.start_byte,\n",
    "            \"end_byte\": node.end_byte,\n",
    "            \"size\": node.end_byte - node.start_byte,\n",
    "        }\n",
    "    \n",
    "    def find_module_imports(self, tree: Tree, code: str) -> list[str]:\n",
    "        \"\"\"Find all import statements in the module.\n",
    "\n",
    "        Returns the source text of every `import` / `from ... import`\n",
    "        statement, in pre-order traversal order.\n",
    "        \"\"\"\n",
    "        imports = []\n",
    "        \n",
    "        # Define import node types\n",
    "        import_types = {\"import_statement\", \"import_from_statement\"}\n",
    "        \n",
    "        # Walk through the tree\n",
    "        cursor = tree.walk()\n",
    "        \n",
    "        # Recursive pre-order walk using the shared tree-sitter cursor;\n",
    "        # the cursor is mutated in place, so call order matters.\n",
    "        def visit_node():\n",
    "            node = cursor.node\n",
    "            \n",
    "            assert node is not None\n",
    "            \n",
    "            # If this is an import node, add it to our list\n",
    "            if node.type in import_types:\n",
    "                imports.append(self.get_node_code(node, code))\n",
    "            \n",
    "            # Continue traversal\n",
    "            if cursor.goto_first_child():\n",
    "                visit_node()\n",
    "                while cursor.goto_next_sibling():\n",
    "                    visit_node()\n",
    "                cursor.goto_parent()\n",
    "        \n",
    "        visit_node()\n",
    "        return imports\n",
    "    \n",
    "    def find_chunkable_nodes(self, tree, code):\n",
    "        \"\"\"Find nodes that can be treated as independent chunks.\n",
    "\n",
    "        Returns a list of node-info dicts (see get_node_info) for every node\n",
    "        whose type is in self.chunkable_types, in pre-order traversal order.\n",
    "        \"\"\"\n",
    "        chunkable_nodes = []\n",
    "        \n",
    "        # Walk through the tree\n",
    "        cursor = tree.walk()\n",
    "        \n",
    "        def visit_node():\n",
    "            node = cursor.node\n",
    "            \n",
    "            # Consistency fix: mirror find_module_imports' walker, which asserts non-None\n",
    "            assert node is not None\n",
    "            \n",
    "            # If this is a chunkable node, add it to our list\n",
    "            if node.type in self.chunkable_types:\n",
    "                chunkable_nodes.append(\n",
    "                    # Use the captured `node` instead of re-reading cursor.node\n",
    "                    self.get_node_info(node, code, parent_type=node.parent.type if node.parent else None)\n",
    "                )\n",
    "            \n",
    "            # Continue traversal\n",
    "            if cursor.goto_first_child():\n",
    "                visit_node()\n",
    "                while cursor.goto_next_sibling():\n",
    "                    visit_node()\n",
    "                cursor.goto_parent()\n",
    "        \n",
    "        visit_node()\n",
    "        return chunkable_nodes\n",
    "    \n",
    "    def break_large_node(self, node, max_size):\n",
    "        \"\"\"Break a large node into smaller chunks based on lines.\n",
    "\n",
    "        Args:\n",
    "            node: node-info dict (from get_node_info) whose code exceeds max_size.\n",
    "            max_size: maximum chunk size in bytes.\n",
    "\n",
    "        Returns:\n",
    "            List of dicts with 'parent_node', 'lines', and 'size' keys.\n",
    "        \"\"\"\n",
    "        node_code = node[\"code\"]\n",
    "        lines = node_code.splitlines()\n",
    "        \n",
    "        chunks = []\n",
    "        current_lines = []\n",
    "        current_size = 0\n",
    "        \n",
    "        for line in lines:\n",
    "            line_size = len(line) + 1  # +1 for newline\n",
    "            \n",
    "            # If adding this line would exceed max chunk size, finalize current chunk\n",
    "            if current_size + line_size > max_size and current_lines:\n",
    "                chunks.append({\n",
    "                    \"parent_node\": node,\n",
    "                    \"lines\": current_lines.copy(),\n",
    "                    \"size\": current_size\n",
    "                })\n",
    "                current_lines = []\n",
    "                current_size = 0\n",
    "            \n",
    "            # Add line to current chunk\n",
    "            # (a single line longer than max_size still becomes its own oversized chunk)\n",
    "            current_lines.append(line)\n",
    "            current_size += line_size\n",
    "        \n",
    "        # Add any remaining lines as the final chunk\n",
    "        if current_lines:\n",
    "            chunks.append({\n",
    "                \"parent_node\": node,\n",
    "                \"lines\": current_lines,\n",
    "                \"size\": current_size\n",
    "            })\n",
    "        \n",
    "        return chunks\n",
    "    \n",
    "    def create_chunks(self, code):\n",
    "        \"\"\"Break the code into semantic chunks based on the AST.\n",
    "\n",
    "        Returns a list of chunk dicts with 'nodes', 'size', 'code', and\n",
    "        'has_imports' keys; the chunk carrying the module imports is\n",
    "        ordered first.\n",
    "        \"\"\"\n",
    "        # Parse the code to get the AST\n",
    "        tree = self.parse_code(code)\n",
    "        \n",
    "        # Get imports\n",
    "        imports = self.find_module_imports(tree, code)\n",
    "        imports_text = \"\\n\".join(imports)\n",
    "        imports_size = len(imports_text) + 2 if imports_text else 0  # +2 for newlines if imports exist\n",
    "        \n",
    "        # Find chunkable nodes\n",
    "        nodes = self.find_chunkable_nodes(tree, code)\n",
    "        \n",
    "        # First, identify oversized nodes that need special handling\n",
    "        regular_nodes = []\n",
    "        line_chunked_nodes = []\n",
    "        \n",
    "        for node in nodes:\n",
    "            # Check if the node alone would exceed our limit\n",
    "            if node[\"size\"] > self.max_chunk_size:\n",
    "                # This node needs to be broken down\n",
    "                sub_chunks = self.break_large_node(node, self.max_chunk_size)\n",
    "                line_chunked_nodes.extend(sub_chunks)\n",
    "            else:\n",
    "                regular_nodes.append(node)\n",
    "        \n",
    "        # Sort regular nodes by size (smallest first to maximize packing)\n",
    "        regular_nodes.sort(key=lambda x: x[\"size\"])\n",
    "        \n",
    "        # Group regular nodes into chunks based on size\n",
    "        semantic_chunks = []\n",
    "        current_chunk = []\n",
    "        current_size = 0\n",
    "        \n",
    "        # Consider imports only for the first chunk\n",
    "        first_chunk_imports_size = imports_size if imports_text else 0\n",
    "        \n",
    "        for node in regular_nodes:\n",
    "            # For the first chunk only, account for imports size\n",
    "            effective_max_size = self.max_chunk_size\n",
    "            effective_current_size = current_size\n",
    "            \n",
    "            # If this would be the first chunk, account for imports\n",
    "            if not semantic_chunks and not current_chunk:\n",
    "                effective_current_size += first_chunk_imports_size\n",
    "            \n",
    "            # If adding this node would exceed max chunk size, finalize current chunk\n",
    "            if effective_current_size + node[\"size\"] > effective_max_size and current_chunk:\n",
    "                # For the first chunk, include imports\n",
    "                if not semantic_chunks and imports_text:\n",
    "                    chunk_code = imports_text + \"\\n\\n\" + \"\\n\".join([n[\"code\"] for n in current_chunk])\n",
    "                else:\n",
    "                    chunk_code = \"\\n\".join([n[\"code\"] for n in current_chunk])\n",
    "                \n",
    "                semantic_chunks.append({\n",
    "                    \"nodes\": current_chunk,\n",
    "                    \"size\": len(chunk_code),\n",
    "                    \"code\": chunk_code,\n",
    "                    # NOTE(review): value is '' (falsy) or the imports string (truthy), not a bool;\n",
    "                    # downstream code only tests truthiness\n",
    "                    \"has_imports\": not semantic_chunks and imports_text  # Only first chunk has imports\n",
    "                })\n",
    "                current_chunk = []\n",
    "                current_size = 0\n",
    "            \n",
    "            # Add node to current chunk\n",
    "            current_chunk.append(node)\n",
    "            current_size += node[\"size\"]\n",
    "        \n",
    "        # Add any remaining nodes as the final chunk\n",
    "        if current_chunk:\n",
    "            if not semantic_chunks and imports_text:\n",
    "                chunk_code = imports_text + \"\\n\\n\" + \"\\n\".join([n[\"code\"] for n in current_chunk])\n",
    "            else:\n",
    "                chunk_code = \"\\n\".join([n[\"code\"] for n in current_chunk])\n",
    "            \n",
    "            semantic_chunks.append({\n",
    "                \"nodes\": current_chunk,\n",
    "                \"size\": len(chunk_code),\n",
    "                \"code\": chunk_code,\n",
    "                \"has_imports\": not semantic_chunks and imports_text  # Only first chunk has imports\n",
    "            })\n",
    "        \n",
    "        # Now handle line-chunked nodes\n",
    "        line_based_chunks = []\n",
    "        for chunked_node in line_chunked_nodes:\n",
    "            node_code = \"\\n\".join(chunked_node[\"lines\"])\n",
    "            \n",
    "            # Only include imports if this would be the first chunk overall\n",
    "            if not semantic_chunks and not line_based_chunks and imports_text:\n",
    "                chunk_code = imports_text + \"\\n\\n\" + node_code\n",
    "                has_imports = True\n",
    "            else:\n",
    "                chunk_code = node_code\n",
    "                has_imports = False\n",
    "            \n",
    "            # Get the original node info but with just the subset of code\n",
    "            parent_node = chunked_node[\"parent_node\"]\n",
    "            line_based_chunks.append({\n",
    "                \"nodes\": [{\n",
    "                    \"type\": parent_node[\"type\"],\n",
    "                    \"name\": parent_node[\"name\"],\n",
    "                    \"parent_type\": parent_node[\"parent_type\"],\n",
    "                    \"code\": node_code,\n",
    "                    \"start_line\": parent_node[\"start_line\"],\n",
    "                    \"end_line\": parent_node[\"end_line\"],\n",
    "                    \"start_byte\": parent_node[\"start_byte\"],\n",
    "                    \"end_byte\": parent_node[\"end_byte\"],\n",
    "                    \"size\": len(node_code),\n",
    "                    \"is_partial\": True,\n",
    "                }],\n",
    "                \"size\": len(chunk_code),\n",
    "                \"code\": chunk_code,\n",
    "                \"has_imports\": has_imports\n",
    "            })\n",
    "        \n",
    "        # Combine both types of chunks\n",
    "        all_chunks = []\n",
    "        \n",
    "        # Ensure the chunk with imports comes first if it exists\n",
    "        import_chunks = [c for c in semantic_chunks + line_based_chunks if c.get(\"has_imports\")]\n",
    "        non_import_chunks = [c for c in semantic_chunks + line_based_chunks if not c.get(\"has_imports\")]\n",
    "        \n",
    "        all_chunks = import_chunks + non_import_chunks\n",
    "        \n",
    "        # Final verification - ensure no chunk exceeds 5000 bytes\n",
    "        # NOTE(review): 5000 is a hard-coded threshold independent of max_chunk_size —\n",
    "        # presumably an embedding-model input limit; confirm and consider making it configurable\n",
    "        for i, chunk in enumerate(all_chunks):\n",
    "            if chunk[\"size\"] > 5000:\n",
    "                print(f\"Warning: Chunk {i} is {chunk['size']} bytes, which exceeds the 5000 byte limit.\")\n",
    "        \n",
    "        # Add a fallback for files with no chunks but valid code\n",
    "        if not all_chunks and code.strip():\n",
    "            # Create a single chunk with the entire file\n",
    "            chunk_code = code.strip()\n",
    "            all_chunks.append({\n",
    "                \"nodes\": [],\n",
    "                \"size\": len(chunk_code),\n",
    "                \"code\": chunk_code,\n",
    "                \"has_imports\": True \n",
    "            })\n",
    "        \n",
    "        return all_chunks\n",
    "    \n",
    "    def create_embeddings_ready_chunks(self, file_path, code, include_metadata=True):\n",
    "        \"\"\"Create chunks with metadata ready for embedding and storage in a vector DB.\n",
    "\n",
    "        Args:\n",
    "            file_path: Path of the source file; used to build stable chunk ids.\n",
    "            code: Full source text of the file.\n",
    "            include_metadata: When True, wrap each chunk with an id and a\n",
    "                metadata dict; when False, return bare code strings.\n",
    "        \"\"\"\n",
    "        chunks = self.create_chunks(code)\n",
    "        \n",
    "        result = []\n",
    "        for i, chunk in enumerate(chunks):\n",
    "            if include_metadata:\n",
    "                result.append({\n",
    "                    # Chunk ids are \"<file_path>_<index>\"; an f-string is already a str\n",
    "                    \"chunk_id\": f\"{file_path}_{i}\",\n",
    "                    \"code\": chunk[\"code\"],\n",
    "                    \"metadata\": {\n",
    "                        \"file_path\": file_path,\n",
    "                        \"size\": chunk[\"size\"],\n",
    "                    }\n",
    "                })\n",
    "            else:\n",
    "                result.append(chunk[\"code\"])\n",
    "        \n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GitHubRepoProcessor:\n",
    "    \"\"\"Thin wrapper over the GitHub REST API for listing, fetching, and diffing repo files.\"\"\"\n",
    "\n",
    "    def __init__(self, owner, repo, token=None):\n",
    "        \"\"\"Initialize with GitHub repository details.\n",
    "\n",
    "        Args:\n",
    "            owner: Repository owner (user or organization).\n",
    "            repo: Repository name.\n",
    "            token: Optional GitHub API token for authenticated requests.\n",
    "        \"\"\"\n",
    "        self.owner = owner\n",
    "        self.repo = repo\n",
    "        self.base_url = f\"https://api.github.com/repos/{owner}/{repo}\"\n",
    "        self.headers = {}\n",
    "        \n",
    "        if token:\n",
    "            self.headers[\"Authorization\"] = f\"token {token}\"\n",
    "        \n",
    "        # Add required GitHub API headers\n",
    "        self.headers[\"Accept\"] = \"application/vnd.github.v3+json\"\n",
    "        self.headers[\"X-GitHub-Api-Version\"] = \"2022-11-28\"\n",
    "    \n",
    "    def get_file_list(self, branch=\"main\", path=\"\"):\n",
    "        \"\"\"Recursively collect Python files reachable from `path` on `branch`.\"\"\"\n",
    "        collected = []\n",
    "        self._get_contents_recursive(branch, path, collected)\n",
    "        return collected\n",
    "    \n",
    "    def _get_contents_recursive(self, branch, path, all_files):\n",
    "        \"\"\"Recursively fetch repository contents.\n",
    "\n",
    "        Appends an info dict for every `.py` file found under `path` to\n",
    "        `all_files` (mutated in place). API errors are printed and that\n",
    "        subtree is skipped.\n",
    "        \"\"\"\n",
    "        url = f\"{self.base_url}/contents/{path}\"\n",
    "        if path == \"\":\n",
    "            url = f\"{self.base_url}/contents\"\n",
    "        \n",
    "        response = requests.get(url, headers=self.headers, params={\"ref\": branch})\n",
    "        \n",
    "        if response.status_code != 200:\n",
    "            print(f\"Error fetching contents at {path}: {response.status_code}\")\n",
    "            print(response.json().get(\"message\", \"\"))\n",
    "            return\n",
    "        \n",
    "        contents = response.json()\n",
    "        \n",
    "        # Handle case where response is a file not a directory\n",
    "        if not isinstance(contents, list):\n",
    "            contents = [contents]\n",
    "        \n",
    "        for item in contents:\n",
    "            # Only Python files are collected; directories are recursed into\n",
    "            if item[\"type\"] == \"file\" and item[\"name\"].endswith(\".py\"):\n",
    "                all_files.append({\n",
    "                    \"path\": item[\"path\"],\n",
    "                    \"download_url\": item[\"download_url\"],\n",
    "                    \"sha\": item[\"sha\"],\n",
    "                    \"size\": item[\"size\"]\n",
    "                })\n",
    "            elif item[\"type\"] == \"dir\":\n",
    "                self._get_contents_recursive(branch, item[\"path\"], all_files)\n",
    "    \n",
    "    def get_file_content(self, branch, file_path):\n",
    "        \"\"\"Get the content of a specific file.\n",
    "\n",
    "        Returns the decoded text, or None if the file could not be fetched.\n",
    "        Files too large for the contents API are fetched via their raw\n",
    "        download URL instead.\n",
    "        \"\"\"\n",
    "        url = f\"{self.base_url}/contents/{file_path}\"\n",
    "        response = requests.get(url, headers=self.headers, params={\"ref\": branch})\n",
    "        \n",
    "        if response.status_code != 200:\n",
    "            print(f\"Error fetching file {file_path}: {response.status_code}\")\n",
    "            return None\n",
    "        \n",
    "        content_data = response.json()\n",
    "        \n",
    "        # The contents API omits 'content' for large files; fall back to the raw URL\n",
    "        if \"content\" not in content_data:\n",
    "            print(f\"File {file_path} is too large for the GitHub API. Getting it directly...\")\n",
    "            # Get the raw file directly\n",
    "            raw_response = requests.get(content_data.get(\"download_url\", \"\"))\n",
    "            if raw_response.status_code == 200:\n",
    "                return raw_response.text\n",
    "            else:\n",
    "                print(f\"Failed to get raw content for {file_path}\")\n",
    "                return None\n",
    "        \n",
    "        # Decode content from base64\n",
    "        content = base64.b64decode(content_data[\"content\"]).decode(\"utf-8\")\n",
    "        return content\n",
    "    \n",
    "    def get_diff_files(self, branch1, branch2) -> dict:\n",
    "        \"\"\"Map changed file paths to their status between two branches.\n",
    "\n",
    "        Returns {file_path: status} where status is 'added', 'modified',\n",
    "        or 'removed', per the GitHub compare endpoint.\n",
    "        \"\"\"\n",
    "        url = f\"{self.base_url}/compare/{branch1}...{branch2}\"\n",
    "        response = requests.get(url, headers=self.headers)\n",
    "        # Fail loudly on API errors instead of raising an opaque KeyError on 'files'\n",
    "        response.raise_for_status()\n",
    "        return {file['filename']: file['status'] for file in response.json()['files']}\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_repo_file(chunker: CodeChunker, github_processor: GitHubRepoProcessor, file_path: str, branch=\"main\"):\n",
    "    \"\"\"Chunk a single repository file; returns [] for missing or empty files.\"\"\"\n",
    "    code = github_processor.get_file_content(branch, file_path)\n",
    "    if code is None or not code.strip():\n",
    "        return []\n",
    "    return chunker.create_embeddings_ready_chunks(file_path, code)\n",
    "\n",
    "def process_repo_files(chunker: CodeChunker, github_processor: GitHubRepoProcessor, branch=\"main\"):\n",
    "    \"\"\"Process all Python files in the repository and generate chunks.\n",
    "\n",
    "    Args:\n",
    "        chunker: CodeChunker used to split each file.\n",
    "        github_processor: GitHub API wrapper for the target repo.\n",
    "        branch: Branch to list and read files from.\n",
    "    \"\"\"\n",
    "    # Bug fix: list files from the requested branch (previously the `branch`\n",
    "    # argument was ignored here and the default branch was always listed)\n",
    "    python_files = github_processor.get_file_list(branch)\n",
    "    print(f\"Found {len(python_files)} Python files.\")\n",
    "    \n",
    "    all_chunks = []\n",
    "    \n",
    "    for file_info in tqdm(python_files, desc=\"Processing files\"):\n",
    "        file_path = file_info[\"path\"]\n",
    "        \n",
    "        try:\n",
    "            file_chunks = process_repo_file(chunker, github_processor, file_path, branch)\n",
    "            all_chunks.extend(file_chunks)\n",
    "        except Exception as e:\n",
    "            # Best-effort: report and continue with the remaining files\n",
    "            print(f\"Error processing {file_path}: {str(e)}\")\n",
    "    \n",
    "    print(f\"Generated {len(all_chunks)} chunks in total.\")\n",
    "    return all_chunks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 27 Python files.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing files: 100%|██████████| 27/27 [00:09<00:00,  2.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated 57 chunks in total.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Chunk every Python file on the repo's main branch\n",
    "# NOTE(review): consider branch=EXISTING_BRANCH for consistency with the config cell\n",
    "chunker = CodeChunker(language=\"python\", max_chunk_size=500)\n",
    "gh_processor = GitHubRepoProcessor(REPO_OWNER, REPO_NAME, os.environ[\"GITHUB_API_KEY\"])\n",
    "chunks = process_repo_files(chunker=chunker, github_processor=gh_processor, branch=\"main\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below are helper functions for forking a collection and adding chunks to a collection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fork an existing collection\n",
    "def get_or_create_new_fork(client: ClientAPI, existing_collection: Collection, new_name: str) -> Collection:\n",
    "    \"\"\"Fork `existing_collection` under `new_name`, or return the fork if it already exists.\n",
    "\n",
    "    If forking fails because the target collection already exists, return it;\n",
    "    otherwise re-raise the original forking error.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return existing_collection.fork(new_name)\n",
    "    except Exception as fork_error:\n",
    "        # Bug fix: get_collection raises (rather than returning None) when the\n",
    "        # collection is missing, so the old `if ... is not None` check could\n",
    "        # never take its else-branch and the real fork error was lost.\n",
    "        try:\n",
    "            return client.get_collection(new_name)\n",
    "        except Exception:\n",
    "            raise fork_error\n",
    "\n",
    "# add chunks to a collection\n",
    "def add_chunks(collection: Collection, chunks: list[dict], batch_size: int = 100):\n",
    "    \"\"\"Add chunks to `collection` in batches of `batch_size` to keep requests small.\"\"\"\n",
    "    for i in range(0, len(chunks), batch_size):\n",
    "        # Slice once per batch instead of three times\n",
    "        batch = chunks[i:i + batch_size]\n",
    "        collection.add(\n",
    "            ids=[chunk[\"chunk_id\"] for chunk in batch],\n",
    "            documents=[chunk[\"code\"] for chunk in batch],\n",
    "            metadatas=[chunk[\"metadata\"] for chunk in batch]\n",
    "        )\n",
    "\n",
    "# populate a collection with the diff files by deleting the existing chunks, regenerating chunks & adding\n",
    "# the diff dictionary looks like this:\n",
    "# {'file_path': 'status'}\n",
    "# status can be 'added', 'modified', 'removed'\n",
    "def populate_branch_diff(collection: Collection, diff_files: dict):\n",
    "    \"\"\"Sync `collection` with the branch changes described by `diff_files`.\n",
    "\n",
    "    NOTE(review): relies on the notebook globals `chunker`, `gh_processor`,\n",
    "    and `NEW_BRANCH` defined in earlier cells — consider passing them in.\n",
    "    \"\"\"\n",
    "    # Delete all existing chunks for every touched file; 'removed' files stay deleted\n",
    "    collection.delete(where={\"file_path\": {\"$in\": list(diff_files.keys())}})\n",
    "    for file_path, status in diff_files.items():\n",
    "        if status == \"added\" or status == \"modified\":\n",
    "            # Re-chunk the new-branch version of the file and re-insert it\n",
    "            chunks = process_repo_file(chunker, gh_processor, file_path, NEW_BRANCH)\n",
    "            add_chunks(collection, chunks)            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ChromaDB Impl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connect to Chroma Cloud; the API key was prompted earlier, and the\n",
    "# tenant/database values here are specific to this demo account\n",
    "client = chromadb.HttpClient(\n",
    "  ssl=True,\n",
    "  host='api.trychroma.com',\n",
    "  tenant='fc152910-6412-4b6b-b67a-4eb229ef50ce',\n",
    "  database='Example Demo',\n",
    "  headers={\n",
    "    'x-chroma-token': os.environ[\"CHROMA_API_KEY\"]\n",
    "  }\n",
    ")\n",
    "  \n",
    "\n",
    "# One collection per repo+branch, embedded with Jina's code model\n",
    "main_collection = client.get_or_create_collection(\n",
    "  name=f\"{REPO_OWNER}_{REPO_NAME}_{EXISTING_BRANCH}\",\n",
    "  configuration={\n",
    "    \"embedding_function\": JinaEmbeddingFunction(\n",
    "      model_name=\"jina-embeddings-v2-base-code\"\n",
    "    )\n",
    "  }\n",
    ")\n",
    "\n",
    "# Upload the chunks generated above\n",
    "add_chunks(\n",
    "  collection=main_collection,\n",
    "  chunks=chunks\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fork the new branch and make updates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute which files changed between the two branches\n",
    "diff = gh_processor.get_diff_files(EXISTING_BRANCH, NEW_BRANCH)\n",
    "\n",
    "# Fork the main-branch collection for the new branch (or reuse an existing fork)\n",
    "new_branch_collection = get_or_create_new_fork(\n",
    "  client=client,\n",
    "  existing_collection=main_collection,\n",
    "  new_name=f\"{REPO_OWNER}_{REPO_NAME}_{NEW_BRANCH}\"\n",
    ")\n",
    "\n",
    "# Apply only the diff to the forked collection\n",
    "populate_branch_diff(new_branch_collection, diff)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query both collections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                                                              metadata                                                                                                                                                                                                                                                                                                                                                                                                                                                                                document  distance\n",
      "id                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      \n",
      "django_web_app/manage.py_2                                       {'file_path': 'django_web_app/manage.py', 'size': 39}                                                                                                                                                                                                                                                                                                                                                                                                                                                     execute_from_command_line(sys.argv)  1.077652\n",
      "django_web_app/django_web_app/__init__.py_0      {'size': 5, 'file_path': 'django_web_app/django_web_app/__init__.py'}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/users/__init__.py_0                        {'size': 5, 'file_path': 'django_web_app/users/__init__.py'}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/users/migrations/__init__.py_0  {'file_path': 'django_web_app/users/migrations/__init__.py', 'size': 5}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/blog/migrations/__init__.py_0    {'file_path': 'django_web_app/blog/migrations/__init__.py', 'size': 5}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/blog/__init__.py_0                          {'size': 5, 'file_path': 'django_web_app/blog/__init__.py'}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/media/Files/speech.py_0              {'size': 443, 'file_path': 'django_web_app/media/Files/speech.py'}  import pyspeech as sr\\n\\ntry:\\r\\n        text = r.recognize_google(audio)\\r\\n        print(\"You said : {}\".format(text))\\r\\n    except:\\r\\n        print(\"Sorry could not recognize your voice\")\\nwith sr.Microphone() as source:\\r\\n    print(\"Speak Anything : \")\\r\\n    audio = r.listen(source)\\r\\n\\r\\n    try:\\r\\n        text = r.recognize_google(audio)\\r\\n        print(\"You said : {}\".format(text))\\r\\n    except:\\r\\n        print(\"Sorry could not recognize your voice\")  1.596439\n",
      "django_web_app/manage.py_0                                      {'file_path': 'django_web_app/manage.py', 'size': 441}             import os\\nimport sys\\nfrom django.core.management import execute_from_command_line\\n\\ntry:\\r\\n        from django.core.management import execute_from_command_line\\r\\n    except ImportError as exc:\\r\\n        raise ImportError(\\r\\n            \"Couldn't import Django. Are you sure it's installed and \"\\r\\n            \"available on your PYTHONPATH environment variable? Did you \"\\r\\n            \"forget to activate a virtual environment?\"\\r\\n        ) from exc  1.622361\n",
      "django_web_app/manage.py_1                                      {'file_path': 'django_web_app/manage.py', 'size': 461}  if __name__ == '__main__':\\n    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'django_web_app.settings')\\n    try:\\n        from django.core.management import execute_from_command_line\\n    except ImportError as exc:\\n        raise ImportError(\\n            \"Couldn't import Django. Are you sure it's installed and \"\\n            \"available on your PYTHONPATH environment variable? Did you \"\\n            \"forget to activate a virtual environment?\"\\n        ) from exc  1.642601\n",
      "django_web_app/users/signals.py_0                        {'file_path': 'django_web_app/users/signals.py', 'size': 419}                                      from django.db.models.signals import post_save\\nfrom django.contrib.auth.models import User\\nfrom django.dispatch import receiver\\nfrom .models import Profile\\n\\nif created:\\r\\n        Profile.objects.create(user=instance)\\ndef save_profile(sender, instance,created, **kwargs):\\r\\n    instance.profile.save()\\ndef create_profile(sender, instance, created, **kwargs):\\r\\n    if created:\\r\\n        Profile.objects.create(user=instance)  1.646785\n",
      "                                                                                                              metadata                                                                                                                                                                                                                                                                                                                                                                                                                                                                                document  distance\n",
      "id                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      \n",
      "django_web_app/blog/print.py_0                               {'size': 19, 'file_path': 'django_web_app/blog/print.py'}                                                                                                                                                                                                                                                                                                                                                                                                                                                                     print(\"Hello Fork\")  0.107969\n",
      "django_web_app/manage.py_2                                       {'file_path': 'django_web_app/manage.py', 'size': 39}                                                                                                                                                                                                                                                                                                                                                                                                                                                     execute_from_command_line(sys.argv)  1.077652\n",
      "django_web_app/django_web_app/__init__.py_0      {'file_path': 'django_web_app/django_web_app/__init__.py', 'size': 5}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/users/__init__.py_0                        {'file_path': 'django_web_app/users/__init__.py', 'size': 5}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/users/migrations/__init__.py_0  {'file_path': 'django_web_app/users/migrations/__init__.py', 'size': 5}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/blog/migrations/__init__.py_0    {'size': 5, 'file_path': 'django_web_app/blog/migrations/__init__.py'}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/blog/__init__.py_0                          {'file_path': 'django_web_app/blog/__init__.py', 'size': 5}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   #init  1.426650\n",
      "django_web_app/media/Files/speech.py_0              {'file_path': 'django_web_app/media/Files/speech.py', 'size': 443}  import pyspeech as sr\\n\\ntry:\\r\\n        text = r.recognize_google(audio)\\r\\n        print(\"You said : {}\".format(text))\\r\\n    except:\\r\\n        print(\"Sorry could not recognize your voice\")\\nwith sr.Microphone() as source:\\r\\n    print(\"Speak Anything : \")\\r\\n    audio = r.listen(source)\\r\\n\\r\\n    try:\\r\\n        text = r.recognize_google(audio)\\r\\n        print(\"You said : {}\".format(text))\\r\\n    except:\\r\\n        print(\"Sorry could not recognize your voice\")  1.596439\n",
      "django_web_app/manage.py_0                                      {'size': 441, 'file_path': 'django_web_app/manage.py'}             import os\\nimport sys\\nfrom django.core.management import execute_from_command_line\\n\\ntry:\\r\\n        from django.core.management import execute_from_command_line\\r\\n    except ImportError as exc:\\r\\n        raise ImportError(\\r\\n            \"Couldn't import Django. Are you sure it's installed and \"\\r\\n            \"available on your PYTHONPATH environment variable? Did you \"\\r\\n            \"forget to activate a virtual environment?\"\\r\\n        ) from exc  1.622361\n",
      "django_web_app/manage.py_1                                      {'size': 461, 'file_path': 'django_web_app/manage.py'}  if __name__ == '__main__':\\n    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'django_web_app.settings')\\n    try:\\n        from django.core.management import execute_from_command_line\\n    except ImportError as exc:\\n        raise ImportError(\\n            \"Couldn't import Django. Are you sure it's installed and \"\\n            \"available on your PYTHONPATH environment variable? Did you \"\\n            \"forget to activate a virtual environment?\"\\n        ) from exc  1.642601\n"
     ]
    }
   ],
   "source": [
    "# Code-search both collections with the same query so the results can be\n",
    "# compared side by side (the forked collection should also surface the\n",
    "# chunk added after the fork).\n",
    "query = \"print('Hello, forking!')\"\n",
    "\n",
    "main_results = main_collection.query(\n",
    "    query_texts=[query],\n",
    "    n_results=10,\n",
    ")\n",
    "\n",
    "new_branch_results = new_branch_collection.query(\n",
    "    query_texts=[query],\n",
    "    n_results=10,\n",
    ")\n",
    "\n",
    "# Iterate over the DataFrames directly; the previous `enumerate` index\n",
    "# was never used.\n",
    "for df in query_result_to_dfs(main_results):\n",
    "    print(df.to_string())\n",
    "\n",
    "for df in query_result_to_dfs(new_branch_results):\n",
    "    print(df.to_string())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
